import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import solve_ivp
from tqdm import tqdm

# Set random seed for reproducibility (all later sampling uses NumPy's global RNG)
np.random.seed(42)

# Simulation parameters
T = 100.0  # Total simulation time
dt = 0.1   # Time step
t_span = (0, T)
# Sample times 0, dt, 2*dt, ..., T - dt (T itself is excluded, so every sample
# lies inside t_span). The grid is built from an integer sample count with
# linspace rather than np.arange with a float step, whose output length can be
# off by one because of floating-point rounding.
n_steps = int(round(T / dt))
t_eval = np.linspace(0.0, T, n_steps, endpoint=False)

# Dimensions of the three representational layers:
#   hyletic -> raw sensorimotor data
#   comp    -> computational primitives
#   lang    -> linguistic representations
dim_hyletic, dim_comp, dim_lang = 5, 10, 15

# System parameters:
#   gamma_min   -> lower bound of the damping coefficient gamma(t)
#   noise_scale -> multiplier applied to standard-normal noise samples
#   alpha       -> strong convexity constant of the quadratic potentials
gamma_min, noise_scale, alpha = 0.1, 0.05, 0.8

# Define potential functions (strictly convex)
def V_comp(P, H):
    """Quadratic potential for the computational primitives.

    Returns ``0.5 * alpha * sum((P - H @ W_HtoP) ** 2)``. The value is zero when P
    equals the hyletic input projected through ``W_HtoP``, and it is strictly
    convex in P because ``alpha > 0``.
    """
    residual = P - H @ W_HtoP
    return 0.5 * alpha * np.sum(residual ** 2)

def grad_V_comp(P, H):
    """Gradient of ``V_comp`` with respect to P: ``alpha * (P - H @ W_HtoP)``."""
    target = H @ W_HtoP
    return alpha * (P - target)

def V_lang(L):
    """Quadratic potential ``0.5 * alpha * sum(L ** 2)`` for linguistic representations."""
    squared = L ** 2
    return 0.5 * alpha * np.sum(squared)

def grad_V_lang(L):
    """Gradient of ``V_lang`` with respect to L, i.e. ``alpha * L``."""
    return L * alpha

# Initialize connection weights (these would typically be learned).
def _random_weights(n_in, n_out):
    """Draw an (n_in, n_out) Gaussian matrix with mean 0 and std 1/sqrt(n_in)."""
    return np.random.normal(0, 1 / np.sqrt(n_in), (n_in, n_out))


W_HtoP = _random_weights(dim_hyletic, dim_comp)  # hyletic -> computational
W_PtoL = _random_weights(dim_comp, dim_lang)     # computational -> linguistic

# Time-varying damping coefficient
def gamma(t):
    """Damping coefficient ``gamma_min + 0.1 * sin(0.1 t) ** 2``.

    Because sin^2 lies in [0, 1], the value stays within
    [gamma_min, gamma_min + 0.1] and never drops below ``gamma_min``.
    """
    oscillation = np.sin(0.1 * t) ** 2
    return gamma_min + 0.1 * oscillation

# Memories and emotions (simplified as functions of time)
def M(t):
    """Memory function: ``0.2 * sin(0.05 t)`` copied into all ``dim_comp`` components."""
    amplitude = 0.2 * np.sin(0.05 * t)
    return np.ones(dim_comp) * amplitude

def E(t):
    """Emotion function: ``0.15 * cos(0.03 t)`` copied into all ``dim_comp`` components."""
    amplitude = 0.15 * np.cos(0.03 * t)
    return np.ones(dim_comp) * amplitude

# Generate synthetic sensorimotor data
def generate_hyletic_data(t):
    """Return a synthetic sensorimotor (hyletic) sample at time ``t``.

    Every component carries the same deterministic signal
    ``0.5 * sin(0.1 t) + 0.3 * cos(0.2 t)``. Independent Gaussian jitter with
    standard deviation 0.1 is then added per component. This draws
    ``dim_hyletic`` values from NumPy's global RNG.
    """
    signal = 0.5 * np.sin(0.1 * t) + 0.3 * np.cos(0.2 * t)
    jitter = 0.1 * np.random.randn(dim_hyletic)
    return np.ones(dim_hyletic) * signal + jitter

# Dynamics for Lemma 5: Computational Primitives
def comp_dynamics(t, P, H):
    """Right-hand side dP/dt for the computational primitives (Lemma 5).

    Args:
        t: Current time; drives the memory and emotion terms M(t) and E(t).
        P: Current computational-primitive state; must broadcast with a
            length-``dim_comp`` vector.
        H: Hyletic input; ``H @ W_HtoP`` is the minimiser of ``V_comp``.

    Returns:
        ``-grad_V_comp(P, H) + gamma_min + gamma_min * M(t) + gamma_min * E(t)``
        plus ``noise_scale`` times a fresh standard-normal sample.
    """
    # NOTE(review): the standalone `+ gamma_min` acts as a constant drive,
    # whereas lang_dynamics uses `- gamma(t) * L` as damping. Confirm this
    # matches Lemma 5.
    # NOTE(review): a new noise sample is drawn on every call, so when this is
    # used as an RK45 right-hand side the noise is not a proper SDE increment.
    drift = -grad_V_comp(P, H)
    drift = drift + gamma_min
    drift = drift + gamma_min * M(t)
    drift = drift + gamma_min * E(t)
    return drift + noise_scale * np.random.randn(dim_comp)

# Dynamics for Lemma I.F: Linguistic Representations
def lang_dynamics(t, L, P_stable):
    """Right-hand side dL/dt for the linguistic representations (Lemma I.F).

    Args:
        t: Current time; sets the damping coefficient gamma(t).
        L: Current linguistic state; must broadcast with a length-``dim_lang``
            vector.
        P_stable: Computational primitives feeding the linear map
            f(Rep_comp) = Rep_comp @ W_PtoL.

    Returns:
        ``-grad_V_lang(L) - gamma(t) * L + P_stable @ W_PtoL`` plus
        ``noise_scale`` times a fresh standard-normal sample.
    """
    damping = gamma(t)
    drive = P_stable @ W_PtoL  # f(Rep_comp): computational -> linguistic features
    eta = noise_scale * np.random.randn(dim_lang)
    return -grad_V_lang(L) - damping * L + drive + eta

def _integrate(rhs, y0):
    """Integrate ``dy/dt = rhs(t, y)`` from ``y0`` over ``t_span``, sampled at ``t_eval``.

    Thin wrapper around ``solve_ivp`` (RK45). It fails loudly rather than let
    callers read ``sol.y[:, -1]`` from a trajectory that stopped early.

    Returns:
        The ``solve_ivp`` result object.

    Raises:
        RuntimeError: if the solver reports ``success=False``.
    """
    sol = solve_ivp(rhs, t_span, y0, t_eval=t_eval, method='RK45')
    if not sol.success:
        raise RuntimeError(f"ODE integration failed: {sol.message}")
    return sol


def run_simulation():
    """Run full simulation to demonstrate Lemma I.F.

    STEP 1 integrates ``comp_dynamics`` for a fixed hyletic input and takes the
    final sample as P*. STEP 2 integrates ``lang_dynamics`` driven by P* and
    measures the distance to the fixed point
    ``L* = (P* @ W_PtoL) / (alpha + gamma(t))``. STEP 3 repeats STEP 2 with a
    randomly perturbed P*.

    All results are printed; nothing is returned. Initial states and noise are
    drawn from NumPy's global RNG.

    Raises:
        RuntimeError: if any ODE integration does not complete.
    """
    print("\n=== SIMULATION OF Lemma I.F: STABILITY OF LINGUISTIC REPRESENTATIONS ===\n")
    print("This simulation demonstrates how linguistic representations stabilize when built upon computational primitives.")
    print("Parameters:")
    print(f"- Dimensions: Hyletic={dim_hyletic}, Computational={dim_comp}, Linguistic={dim_lang}")
    print(f"- Simulation time: {T} units with {dt} step size")
    print(f"- Minimum damping coefficient: {gamma_min}")
    print(f"- Noise scale: {noise_scale}")
    print(f"- Strong convexity parameter: {alpha}")

    # STEP 1: simulate computational dynamics to reach a stable state P*.
    print("\nSTEP 1: Computing stable computational primitives (P*) as per Lemma 5...")

    P_initial = np.random.randn(dim_comp)
    # Fixed hyletic data (sampled at an arbitrary time point) so P* has a fixed target.
    H_fixed = generate_hyletic_data(50)

    def comp_system(t, y):
        return comp_dynamics(t, y, H_fixed)

    sol_comp = _integrate(comp_system, P_initial)
    P_star = sol_comp.y[:, -1]

    # Distance of the last 5 samples from the final sample. This is only a
    # settling check; P* is not compared against an analytic fixed point.
    comp_tail_dev = np.linalg.norm(sol_comp.y[:, -5:] - P_star[:, np.newaxis], axis=0)

    print(f"Stable computational primitive P* reached with L2 norm: {np.linalg.norm(P_star):.4f}")
    print(f"Last 5 iterations show convergence: {comp_tail_dev}")

    # STEP 2: simulate linguistic dynamics driven by the stable P*.
    print("\nSTEP 2: Simulating linguistic representation dynamics as per Lemma I.F...")

    L_initial = np.random.randn(dim_lang)

    def lang_system(t, y):
        return lang_dynamics(t, y, P_star)

    sol_lang = _integrate(lang_system, L_initial)
    L_final = sol_lang.y[:, -1]

    # At equilibrium: ∇V(Rep_lang*) + γ(t)Rep_lang* = W · P*
    # For quadratic potential: αL* + γL* = W·P* => L* = W·P*/(α+γ)
    # γ varies in time, so evaluate it at the time of the last sample. t_eval
    # stops short of T, so gamma(T) would not line up with L_final.
    t_final = sol_lang.t[-1]
    gamma_final = gamma(t_final)
    L_equilibrium = (P_star @ W_PtoL) / (alpha + gamma_final)

    lang_distances = np.linalg.norm(sol_lang.y - L_equilibrium[:, np.newaxis], axis=0)

    # Ratio of final to initial distance from equilibrium; near 0 means strong decay.
    convergence_rate = lang_distances[-1] / lang_distances[0]

    print(f"Linguistic representation L converged to equilibrium with L2 norm: {np.linalg.norm(L_final):.4f}")
    print(f"Theoretical equilibrium norm: {np.linalg.norm(L_equilibrium):.4f}")
    print(f"Distance to equilibrium: {np.linalg.norm(L_final - L_equilibrium):.6f}")
    print(f"Convergence rate: {convergence_rate:.6f} (closer to 0 means better convergence)")

    # Convergence pattern over the last (up to) 10 samples.
    print("\nConvergence pattern for linguistic representation (distance to equilibrium):")
    for t_i, dist in zip(sol_lang.t[-10:], lang_distances[-10:]):
        print(f"t={t_i:.1f}: {dist:.6f}")

    print("\nVerifying conditions from Lemma I.F:")

    # 1. Computational primitives settled (tail deviation computed above).
    comp_converged = np.all(comp_tail_dev < 0.1)
    print(f"1. Computational primitives converged: {comp_converged}")

    # 2. The mapping f is linear, hence continuous.
    print("2. Mapping f(Rep_comp) = W · Rep_comp is linear and continuous: True")

    # 3. Compare a heuristic noise-level estimate with the spread of the
    #    last 20 distances to equilibrium.
    noise_effect = noise_scale / (alpha + gamma_min)
    expected_fluctuation = noise_effect * np.sqrt(dim_lang)
    actual_fluctuation = np.std(lang_distances[-20:])
    print(f"3. Effect of bounded noise: Expected fluctuation ~{expected_fluctuation:.6f}, Actual ~{actual_fluctuation:.6f}")

    # 4. Stability criterion from the extended derivation.
    stability_term = alpha + gamma_min
    print(f"4. Stability criterion (α + γ > 0): {stability_term:.4f} > 0 - {stability_term > 0}")

    # STEP 3: perturb P* and watch the linguistic equilibrium move.
    print("\nSTEP 3: Demonstrating how linguistic representations adapt to changes in computational primitives...")

    P_modified = P_star + 0.2 * np.random.randn(dim_comp)
    print(f"Distance between original and modified P*: {np.linalg.norm(P_modified - P_star):.4f}")

    def lang_system_modified(t, y):
        return lang_dynamics(t, y, P_modified)

    sol_lang_mod = _integrate(lang_system_modified, L_initial)

    # Equilibrium for the modified drive, with gamma taken at this run's last sample time.
    L_eq_modified = (P_modified @ W_PtoL) / (alpha + gamma(sol_lang_mod.t[-1]))
    lang_mod_distances = np.linalg.norm(sol_lang_mod.y - L_eq_modified[:, np.newaxis], axis=0)

    print(f"Modified linguistic equilibrium norm: {np.linalg.norm(L_eq_modified):.4f}")
    print(f"Distance to modified equilibrium: {lang_mod_distances[-1]:.6f}")
    print(f"Distance between original and modified linguistic equilibria: {np.linalg.norm(L_equilibrium - L_eq_modified):.4f}")

    # Conclusion
    print("\nCONCLUSION:")
    print("The simulation demonstrates the key claims of Lemma I.F:")
    print("1. Linguistic representations (Rep_lang) converge to stable states")
    print("2. These stable states depend continuously on the underlying computational primitives")
    print("3. The system remains stable under bounded perturbations and noise")
    print("4. The damping term γ(t) ensures dissipation of energy and convergence")
    print("\nThe results support the hierarchical structure where linguistic capabilities build upon")
    print("stable computational primitives, forming a system of nested attractors.")

    if convergence_rate < 0.1 and stability_term > 0 and comp_converged:
        print("\n✓ All conditions of Lemma I.F are satisfied and verified by simulation.")
    else:
        print("\n⚠ Some conditions of Lemma I.F may not be fully satisfied in this simulation.")

if __name__ == "__main__":
    run_simulation()